#!/usr/bin/env python3
"""
Command line interface for Jinja2.

Copyright (c) 2017, Murray Andrews
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

1.  Redistributions of source code must retain the above copyright notice, this
    list of conditions and the following disclaimer.

2.  Redistributions in binary form must reproduce the above copyright notice,
    this list of conditions and the following disclaimer in the documentation
    and/or other materials provided with the distribution.

3.  Neither the name of the copyright holder nor the names of its contributors
    may be used to endorse or promote products derived from this software
    without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""

from __future__ import annotations

import argparse
import getpass
import json
import os
import sys
from datetime import datetime, timezone
from os.path import basename, splitext
from subprocess import run
from typing import Any

import boto3
import jinja2
import yaml

__author__ = 'Murray Andrews'

# Program name: the invoked script's file name without directory or extension.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

# Default search path for list/parameter files. The empty entry before the
# leading ':' stands for the current directory; /usr/local/lib/PROG follows.
DEFAULT_PATH = f':{os.path.join("/usr/local/lib", PROG)}'

# Environment variable that overrides DEFAULT_PATH (PROG upper-cased + 'PATH').
PATH_ENV_VAR = f'{PROG.upper()}PATH'

# Closing delimiter for each opening delimiter accepted by -d/--delimiter.
# Characters not listed here are used for both opening and closing.
DELIMS = dict(zip('{<[(', '}>])'))


class Aws:
    """
    Get some AWS related info items + helper functions.

    Results of AWS API calls (caller identity, ECR registry ID) are fetched
    lazily on first use and cached on the instance.

    :param aws_session: A boto3 Session. If not specified a default is created.
    """

    # Service indicator accepted by arn() -> ARN template.
    _ARN_TEMPLATES = {
        'lambda-function': 'arn:{partition}:lambda:{region}:{account}:function:{resource}',
        'log-group': 'arn:{partition}:logs:{region}:{account}:log-group:{resource}',
        'sns-topic': 'arn:{partition}:sns:{region}:{account}:{resource}',
        'sqs-queue': 'arn:{partition}:sqs:{region}:{account}:{resource}',
        'iam-role': 'arn:{partition}:iam::{account}:role/{resource}',
    }

    # --------------------------------------------------------------------------
    def __init__(self, aws_session: boto3.Session | None = None):
        """Init."""

        self.aws_session = aws_session or boto3.Session()
        self._identity = None
        self._ecr_id = None

    # --------------------------------------------------------------------------
    @property
    def identity(self) -> dict[str, Any]:
        """The response to STS get_caller_identity (cached after first use)."""

        if not self._identity:
            self._identity = self.aws_session.client('sts').get_caller_identity()
        return self._identity

    # --------------------------------------------------------------------------
    @property
    def account(self) -> str:
        """AWS account ID."""

        return self.identity['Account']

    # --------------------------------------------------------------------------
    @property
    def partition(self) -> str:
        """
        AWS partition of the caller (e.g. aws, aws-cn, aws-us-gov).

        Taken from the second field of the caller identity ARN
        (arn:PARTITION:SERVICE:REGION:ACCOUNT:RESOURCE).
        """

        return self.identity['Arn'].split(':')[1]

    # --------------------------------------------------------------------------
    @property
    def user(self) -> str:
        """
        AWS user.

        For an IAM user ARN (...:user/PATH/NAME) this is the last component, so
        IAM paths are skipped. For other ARNs whose resource has components
        (e.g. ...:assumed-role/ROLE/SESSION) it is the second component (the
        role name). If the resource has no '/' (e.g. the account root) the
        resource itself is returned.

        :return:        The AWS user name or role name.
        """

        # The resource part is everything after the 5th ':' and may itself
        # contain ':' characters, hence the maxsplit.
        resource = self.identity['Arn'].split(':', 5)[5]
        parts = resource.split('/')
        if parts[0] == 'user':
            return parts[-1]
        return parts[1] if len(parts) > 1 else parts[0]

    # --------------------------------------------------------------------------
    @property
    def region(self) -> str:
        """AWS region of the boto3 session."""

        return self.aws_session.region_name

    # --------------------------------------------------------------------------
    @property
    def ecr_uri(self) -> str:
        """ECR registry base URI (REGISTRY_ID.dkr.ecr.REGION.amazonaws.com), cached."""

        if not self._ecr_id:
            registry_id = self.aws_session.client('ecr').describe_registry()['registryId']
            self._ecr_id = f'{registry_id}.dkr.ecr.{self.region}.amazonaws.com'
        return self._ecr_id

    # --------------------------------------------------------------------------
    def arn(self, service: str, resource: str) -> str:
        """
        Build ARNs for some services.

        This doesn't lookup real resources. The partition and account are taken
        from the caller identity and the region from the session.

        :param service:     An AWS service indicator. One of lambda-function,
                            log-group, sns-topic, sqs-queue, iam-role.
        :param resource:    The specific resource identifier.
        :return:            An ARN for the resource. Possibly.
        :raise ValueError:  If the service indicator is not known.
        """

        # Validate the service before touching identity so an unknown service
        # doesn't trigger an STS call.
        try:
            template = self._ARN_TEMPLATES[service]
        except KeyError:
            raise ValueError(f'arn: Unknown service: {service}') from None

        return template.format(
            partition=self.partition,
            region=self.region,
            account=self.account,
            resource=resource,
        )


# ..............................................................................
# region built in params
# ..............................................................................


# ------------------------------------------------------------------------------
def _param_datetime():
    """
    Return a dictionary of date/time related params.

    :return:        Date/time related params
    :rtype:         dict[str, str]
    """

    utc_now = datetime.now(timezone.utc)
    return {
        'ctime': str(datetime.now().ctime()),
        'datetime': str(datetime.now().strftime('%Y-%m-%d %H:%M:%S')),
        'utc_ctime': str(utc_now.ctime()),
        'utc_datetime': str(utc_now.strftime('%Y-%m-%d %H:%M:%S')),
        'iso_datetime': str(utc_now.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'),
    }


# ------------------------------------------------------------------------------
def _param_jinja():
    """
    Return parameters describing this program.

    :return:        A single-entry dict mapping 'prog' to the program name.
    :rtype:         dict[str, str]
    """

    return dict(prog=PROG)


# ------------------------------------------------------------------------------
def _param_user():
    """
    Return parameters about the current user.

    :return:        User parameters
    :rtype:         dict[str, str]

    """

    return {'user': getpass.getuser()}


# ------------------------------------------------------------------------------
def builtin_params():
    """
    Return a dictionary of all the internally defined params.

    Every module-level callable whose name starts with ``_param_`` is called,
    in module definition order, and the dictionaries they return are merged.
    Later providers override earlier ones on key clashes.

    :return:    A dictionary of built-in parameters.
    :rtype:     dict[str, str]
    """

    merged = {}
    providers = (
        obj for name, obj in globals().items() if name.startswith('_param_') and callable(obj)
    )
    for provider in providers:
        merged.update(provider())
    return merged


# ..............................................................................
# endregion built in params
# ..............................................................................

# ..............................................................................
# region lava params
# ..............................................................................


def lava_dag(source: str, **kwargs: Any) -> str:
    """
    Generate a lava DAG.

    This is a bit of a hack -- it just runs lava-dag-gen. Note that this
    returns JSON text. This works in a YAML source file because valid JSON is
    also valid YAML. Neat eh?

    :param source:      See lava-dag-gen.
    :param kwargs:      Any of the `--option value` options of lava-dag-gen.
                        Each keyword is passed verbatim as `--keyword`. Values
                        are converted with str(), so non-string template values
                        such as integers can be used.
    :return:            The standard output of lava-dag-gen (the DAG as JSON).
    :raise Exception:   If lava-dag-gen exits with a non-zero status.
    """

    args = ['lava-dag-gen']
    for option, value in kwargs.items():
        # subprocess only accepts str/bytes/path arguments.
        args.extend([f'--{option}', str(value)])
    args.append(str(source))

    result = run(args, capture_output=True, text=True)
    if result.returncode:
        raise Exception(
            result.stderr or f'lava-dag-gen failed with exit status {result.returncode}'
        )

    return result.stdout


# ------------------------------------------------------------------------------
def _lava_dag() -> dict[str, Any]:
    """
    Expose the lava DAG generator to templates.

    :return:    A one-entry dictionary mapping 'dag' to the lava_dag function.
    """

    params: dict[str, Any] = {}
    params['dag'] = lava_dag
    return params


# ------------------------------------------------------------------------------
def _lava_aws() -> dict[str, Any]:
    """
    Expose AWS information and helpers to templates.

    A new Aws instance, built with its default session, is created on each call.

    :return:    A one-entry dictionary mapping 'aws' to an Aws instance.
    """

    helper = Aws()
    return dict(aws=helper)


# ------------------------------------------------------------------------------
def lava_params() -> dict[str, Any]:
    """
    Return a dictionary of all the internally defined lava params.

    Every module-level callable whose name starts with ``_lava_`` is called, in
    module definition order, and the dictionaries they return are merged.
    Later providers override earlier ones on key clashes.

    :return:    A dictionary of built-in lava parameters.
    """

    collected: dict[str, Any] = {}
    for name, candidate in globals().items():
        if not name.startswith('_lava_'):
            continue
        if callable(candidate):
            collected.update(candidate())
    return collected


# ..............................................................................
# endregion lava params
# ..............................................................................


# ------------------------------------------------------------------------------
class StoreNameValuePair(argparse.Action):
    """
    Store argparse values from options of the form --option name=value.

    The destination (self.dest) will be created as a dict {name: value}. This
    allows multiple name-value pairs to be set for the same option.

    Usage is:

        argparser.add_argument('-x', metavar='key=value', action=StoreNameValuePair)

    or
        argparser.add_argument('-x', metavar='key=value ...', action=StoreNameValuePair,
                               nargs='+')

    """

    # --------------------------------------------------------------------------
    def __call__(self, parser, namespace, values, option_string=None):
        """
        Handle name=value option.

        Only the first '=' separates the name from the value, so values may
        themselves contain '='. A later occurrence of a name overrides an
        earlier one.

        :raise argparse.ArgumentError: If a value contains no '='.
        """

        # Build a fresh dict rather than mutating in place, so a dict given as
        # the argument's default is never modified.
        argdict = dict(getattr(namespace, self.dest, None) or {})

        if not isinstance(values, list):
            values = [values]
        for val in values:
            name, sep, value = val.partition('=')
            if not sep:
                raise argparse.ArgumentError(self, f'expected name=value, got {val!r}')
            argdict[name] = value

        setattr(namespace, self.dest, argdict)


# ------------------------------------------------------------------------------
def find_in_path(filename, path):
    """
    Find a file in a path.

    If an absolute path is provided in filename then it is merely checked for
    existence and returned unchanged if it exists. Otherwise each directory in
    the path is searched in order, and the absolute path of the first match is
    returned. Directories never count as a match, so a directory with the
    wanted name cannot shadow a real file later in the path. Non-directory
    special files (e.g. /dev/stdin, pipes) are still accepted.

    :param filename:    The name of the file to find.
    :param path:        The path. May be either a string of : separated
                        directories or an iterable of dir names. An empty path
                        ('', [] or None) is treated as current directory only.
                        An empty entry in a : separated string (e.g. a leading
                        ':') also means the current directory.
    :type filename:     str
    :type path:         str | collections.abc.Iterable[str] | None
    :return:            The absolute path of the file if found otherwise None.
    :rtype:             str | None
    :raise ValueError:  If filename is empty.

    """

    def _usable(candidate):
        # Exists and is not a directory.
        return os.path.exists(candidate) and not os.path.isdir(candidate)

    if not filename:
        raise ValueError('filename must be specified')

    if os.path.isabs(filename):
        return filename if _usable(filename) else None

    if not path:
        path = ['.']
    elif isinstance(path, str):
        path = path.split(':')

    for d in path:
        # os.path.join('', name) == name, i.e. relative to the cwd.
        p = os.path.join(d, filename)
        if _usable(p):
            return os.path.abspath(p)

    return None


# ------------------------------------------------------------------------------
def process_cli_args():
    """
    Process the command line arguments.

    Also checks that at most one input (template or list file) is read from
    stdin.

    :return:    The args namespace.
    :raise Exception: If more than one input would be read from stdin.
    """

    argp = argparse.ArgumentParser(
        prog=PROG, description='Render a Jinja template with specified parameters.'
    )

    argp.add_argument(
        '-a',
        '--autoescape',
        action='store_true',
        help='Enable Jinja autoescape. By default it is disabled for backward compatibility.',
    )

    argp.add_argument(
        '-B',
        '--no-builtins',
        action='store_false',
        dest='builtins',
        default=True,
        help='Don\'t include any built-in parameters.',
    )

    argp.add_argument(
        '-c',
        '--context-dir',
        metavar='DIR',
        dest='context_dir',
        action='store',
        default='.',
        help=(
            'Set the context directory for the rendering environment.'
            ' This is used to locate any subordinate templates referenced'
            ' via the Jinja "include" directive. Defaults to the current'
            ' directory.'
        ),
    )

    argp.add_argument(
        '-d',
        '--delimiter',
        metavar='CHAR',
        action='store',
        default='{',
        help=(
            'Set the Jinja2 delimiters instead of the default {{...}},'
            ' {%%...%%}, {#...#}. Only the first character is used and'
            ' it will replace the outer curly bracket only. If the'
            ' character selected has a natural left and right variant'
            ' then the obvious pairing is used. Otherwise the same'
            ' character is used for opening and closing. This option'
            ' is useful if the text to be rendered is itself Jinja'
            ' which must be left untouched.'
        ),
    )

    argp.add_argument(
        '-E',
        '--no-environ',
        action='store_false',
        dest='environ',
        default=True,
        help='Don\'t include environment variables in the parameters.',
    )

    argp.add_argument(
        '-f',
        '--file',
        action='append',
        metavar='FILE',
        help=(
            'Get parameters from the specified file. File names  with a .json or'
            ' .jsn suffix are assumed to be in JSON format, otherwise YAML format'
            ' is assumed. Any parameters specified by -p/--param options or'
            ' -l/--list options will override values in the param file. Can be'
            ' specified multiple times. Files are searched for in the path'
            f' specified by the {PATH_ENV_VAR} environment variable.'
        ),
    )

    argp.add_argument(
        '-l',
        '--list',
        action=StoreNameValuePair,
        metavar='name=FILE',
        help=(
            'Get the values of a list parameter from the specified file, one'
            ' value per line. If - then read stdin. Multiple list parameters can'
            ' be specified using multiple -l/--list arguments. Only one file can'
            ' be stdin and then only if the template is not read from stdin.'
            f' Files are searched for in the path specified by the {PATH_ENV_VAR}'
            ' environment variable.'
        ),
    )

    argp.add_argument(
        '-P',
        '--no-path',
        action='store_false',
        dest='use_path',
        default=True,
        help=f'Suppress search for list and parameter files in {PATH_ENV_VAR}.',
    )

    argp.add_argument(
        '-p',
        '--param',
        action=StoreNameValuePair,
        metavar='name=VALUE',
        help=(
            'Set the value of a parameter to be fed in to the Jinja template.'
            ' Multiple parameters can be specified using multiple -p/--param arguments.'
        ),
    )

    argp.add_argument(
        'template',
        action='store',
        nargs='*',
        help=(
            'Name of the file containing the Jinja template. If not'
            ' specified or -, read the template from stdin.'
        ),
    )

    args = argp.parse_args()

    # Count the inputs that will be read from stdin. args.template is a list
    # (nargs='*'): the template comes from stdin when none is given, and once
    # for every template named '-'.
    stdin_count = args.template.count('-') if args.template else 1
    if args.list:
        stdin_count += sum(1 for f in args.list.values() if f == '-')
    if stdin_count > 1:
        raise Exception('Only one input can be from stdin')

    return args


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Do the business.

    Parse the command line, assemble the rendering parameters (parameter
    files, -p params, list files, built-ins, environment) and print each
    rendered template to stdout.

    :return:        Status (0). Failures are reported by raising Exception.
    :rtype:         int
    """

    args = process_cli_args()

    render_params: dict = {}

    # None disables searching: file names are then used exactly as given.
    path = os.environ.get(PATH_ENV_VAR, DEFAULT_PATH) if args.use_path else None

    # ----------------------------------------
    # Process parameter files -- JSON or YAML. Use the search path.
    if args.file:
        for paramfile in args.file:
            f = find_in_path(paramfile, path) if path else paramfile
            if not f:
                raise Exception(f'No such file or directory: {paramfile}')

            loader = json.load if splitext(f)[1].lower() in ('.json', '.jsn') else yaml.safe_load

            with open(f) as fp:
                file_params = loader(fp)

            # An empty YAML file (or JSON null) loads as None: no parameters.
            if file_params is None:
                continue
            if not isinstance(file_params, dict):
                raise Exception(f'{paramfile}: parameter file must contain a mapping')
            render_params.update(file_params)

    # ----------------------------------------
    # Process command line specified params
    if args.param:
        render_params.update(args.param)

    # ----------------------------------------
    # Process list files. Use the search path.
    if args.list:
        for k, v in args.list.items():
            if v == '-':
                render_params[k] = [s.strip() for s in sys.stdin.readlines()]
            else:
                f = find_in_path(v, path) if path else v
                if not f:
                    raise Exception(f'No such file or directory: {v}')
                with open(f) as fp:
                    render_params[k] = [s.strip() for s in fp.readlines()]

    if args.builtins:
        render_params.setdefault('jinja', {})
        for k, v in builtin_params().items():
            render_params['jinja'][k] = v
        render_params['jinja']['templates'] = args.template  # noqa
        # Add the lava specific builtins
        render_params.setdefault('lava', {})
        for k, v in lava_params().items():
            render_params['lava'][k] = v

    if args.environ:
        render_params['environ'] = os.environ

    # ----------------------------------------
    # Setup a Jinja2 environment
    j_start = args.delimiter[0]
    j_end = DELIMS.get(j_start, j_start)
    j_env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(args.context_dir),
        block_start_string=j_start + '%',
        block_end_string='%' + j_end,
        variable_start_string=j_start + '{',
        variable_end_string='}' + j_end,
        comment_start_string=j_start + '#',
        comment_end_string='#' + j_end,
        autoescape=args.autoescape,  # noqa: S701
    )

    # ----------------------------------------
    # Read the templates and render them.
    if not args.template:
        # Stdin
        args.template = ['-']

    for template in args.template:
        f = None
        if template != '-':
            f = find_in_path(template, path) if path else template
            if not f:
                # Raised outside the try below so the template name is not
                # prefixed twice by the generic handler.
                raise Exception(f'{template}: No such file or directory')
        try:
            if f is None:
                jt = j_env.from_string(sys.stdin.read())
            else:
                with open(f) as fp:
                    jt = j_env.from_string(fp.read())
            print(jt.render(**render_params))
        except jinja2.TemplateNotFound as e:
            raise Exception(f'{template}: template {e} not found') from e
        except Exception as e:
            raise Exception(f'{template}: {e}') from e

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions propagate with a traceback)
    # sys.exit(main())  # noqa: ERA001
    try:
        main()
    except Exception as ex:
        # Report errors as a one-line message rather than a traceback.
        print(f'{PROG}: {ex}', file=sys.stderr)
        # sys.exit, unlike the site-provided exit() builtin, is always available.
        sys.exit(1)

    sys.exit(0)
